{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 30,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "import pandas as pd\n",
    "\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "from torch.utils.data import TensorDataset, DataLoader\n",
    "\n",
    "import time\n",
    "import tqdm\n",
    "import itertools\n",
    "from allennlp.nn import util\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 27,
   "metadata": {},
   "outputs": [],
   "source": [
    "from chinese_gpt import TransformerEncoder as Encoder\n",
    "from chinese_gpt import TransformerDecoderLM as Decoder\n",
    "from pytorch_pretrained_bert import BertModel, BertTokenizer, OpenAIAdam"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Build the encoder/decoder with their default constructor arguments, then\n",
    "# restore their weights from checkpoints in the working directory.\n",
    "encoder = Encoder()\n",
    "decoder = Decoder()\n",
    "\n",
    "# map_location=\"cpu\" keeps the load independent of the device the checkpoint\n",
    "# was saved from; the models are moved to the training device in a later cell.\n",
    "# NOTE(review): torch.load unpickles the file -- only load trusted checkpoints.\n",
    "encoder.load_state_dict(torch.load(\"encoder.pth\", map_location=\"cpu\"))\n",
    "decoder.load_state_dict(torch.load(\"decoder.pth\", map_location=\"cpu\"))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the pre-built training tensors. Later cells unpack this object with\n",
    "# TensorDataset(*train_data) and read six tensors per batch.\n",
    "train_data = torch.load(\"train_data.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Number of samples per forward/backward pass.\n",
    "batch_size = 16\n",
    "\n",
    "# Wrap the loaded tensors as one dataset and iterate over it in shuffled\n",
    "# mini-batches.\n",
    "train_dataset = TensorDataset(*train_data)\n",
    "train_dataloader = DataLoader(\n",
    "    train_dataset,\n",
    "    batch_size=batch_size,\n",
    "    shuffle=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Prefer the GPU, but fall back to the CPU so the notebook still runs on a\n",
    "# machine without CUDA.\n",
    "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
    "\n",
    "encoder = encoder.to(device)\n",
    "decoder = decoder.to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "num_epochs = 10\n",
    "# Number of mini-batches whose gradients are accumulated before one optimizer step.\n",
    "num_gradients_accumulation = 4\n",
    "\n",
    "# Schedule length handed to the optimizer as t_total. len(train_dataloader)\n",
    "# also counts the final partial batch of each epoch, and the division is\n",
    "# rounded up so the schedule length is never smaller than the number of\n",
    "# optimizer steps the training loop takes.\n",
    "num_train_batches = len(train_dataloader) * num_epochs\n",
    "num_train_optimization_steps = (\n",
    "    (num_train_batches + num_gradients_accumulation - 1) // num_gradients_accumulation\n",
    ")\n",
    "\n",
    "# Encoder and decoder are optimized jointly. Parameters whose name contains\n",
    "# one of the `no_decay` fragments (biases, LayerNorm) get no weight decay.\n",
    "param_optimizer = list(encoder.named_parameters()) + list(decoder.named_parameters())\n",
    "no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']\n",
    "optimizer_grouped_parameters = [\n",
    "    {'params': [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], 'weight_decay': 0.01},\n",
    "    {'params': [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], 'weight_decay': 0.0}\n",
    "    ]\n",
    "\n",
    "\n",
    "optimizer = OpenAIAdam(optimizer_grouped_parameters,\n",
    "                       lr=1e-5,\n",
    "                       warmup=0.01,\n",
    "                       max_grad_norm=1.0,\n",
    "                       weight_decay=0.01,\n",
    "                       t_total=num_train_optimization_steps)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "71c2c0d23c6348279c9958b7a296efb3",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=1280), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "b69cc7bf99a14c5bb6bc077546cd3a8d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=1280), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a1b7e887842a459c97c8291ecdfe77d6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=1280), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dfe68e71dbdb4ca385aeec95e9e407cc",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=1280), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "01d8745983044e86b08001cd26ed3aa6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=1280), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6f59187f65b34fd785073ba00095d383",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=1280), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7d52ee9567214470964dc3694ee7ae53",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=1280), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3f1949a94bb5466db2b0b211f91d1cdf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=1280), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7cf86a4964674dc0af8006c1d2cac647",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=1280), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "e24cf368e65e400dbcb475cc5111cea8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "HBox(children=(IntProgress(value=0, max=1280), HTML(value='')))"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\n"
     ]
    }
   ],
   "source": [
    "# Fine-tune encoder and decoder jointly with gradient accumulation; the\n",
    "# weights of both models are saved at the end of every epoch.\n",
    "update_count = 0\n",
    "start = time.time()\n",
    "\n",
    "# Make the training mode explicit and drop any gradients left over from a\n",
    "# previous (possibly interrupted) run of this cell.\n",
    "encoder.train()\n",
    "decoder.train()\n",
    "optimizer.zero_grad()\n",
    "\n",
    "for ep in range(num_epochs):\n",
    "\n",
    "    pb = tqdm.tqdm_notebook(train_dataloader)\n",
    "\n",
    "    for batch in pb:\n",
    "        batch = [item.to(device) for item in batch]\n",
    "\n",
    "        (encoder_input,\n",
    "         third,\n",
    "         mask_encoder_input,\n",
    "         mask_third,\n",
    "         encoder_type_ids,\n",
    "         third_type_ids) = batch\n",
    "\n",
    "        # The encoder's second output is handed to the decoder as `past`.\n",
    "        _, past = encoder(encoder_input, mask_encoder_input, encoder_type_ids)\n",
    "\n",
    "        # Decoder mask: encoder mask followed by the target mask along dim 1.\n",
    "        mask = torch.cat([mask_encoder_input, mask_third], dim=1)\n",
    "        logits, _ = decoder(third, mask, past=past, past_length=0)\n",
    "\n",
    "        # Next-token objective: logits at position t are scored against token t + 1.\n",
    "        out = logits[:, :-1].contiguous()\n",
    "        target = third[:, 1:].contiguous()\n",
    "        target_mask = mask_third[:, 1:].contiguous()\n",
    "\n",
    "        loss = util.sequence_cross_entropy_with_logits(out, target, target_mask, average=\"token\")\n",
    "        # Scale so the accumulated gradient is the mean over the accumulated\n",
    "        # mini-batches rather than their sum; `loss` stays unscaled for logging.\n",
    "        (loss / num_gradients_accumulation).backward()\n",
    "\n",
    "        update_count += 1\n",
    "\n",
    "        # Step only after a full group of `num_gradients_accumulation` batches.\n",
    "        if update_count % num_gradients_accumulation == 0:\n",
    "            optimizer.step()\n",
    "            optimizer.zero_grad()\n",
    "\n",
    "            # Throughput in samples per second since the previous optimizer step.\n",
    "            end = time.time()\n",
    "            speed = batch_size * num_gradients_accumulation / (end - start)\n",
    "            start = end\n",
    "\n",
    "            # Loss of the most recent mini-batch only (not averaged over the group).\n",
    "            record_loss = loss.item()\n",
    "            perplexity = np.exp(record_loss)\n",
    "\n",
    "            pb.set_postfix(loss=record_loss, perplexity=perplexity, speed=speed)\n",
    "\n",
    "    torch.save(encoder.state_dict(), str(ep)+\"encoder.pth\")\n",
    "    torch.save(decoder.state_dict(), str(ep)+\"decoder.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Save the weights once more, tagged with the last epoch index `ep` left\n",
    "# behind by the training loop.\n",
    "torch.save(encoder.state_dict(), str(ep)+\"encoder.pth\")\n",
    "torch.save(decoder.state_dict(), str(ep)+\"decoder.pth\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
